import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Linear(4, 64),
    nn.ReLU(),
    nn.Linear(64, 128),
    nn.ReLU(),
    nn.Linear(128, 2)
)

# 1. print()
print(net)

# 2. summary
from torchsummary import summary

summary(net, input_size=(4,), batch_size=10, device='cpu')

# 3. netron
# 3.1 终端输入 netron 到网址上查看  (直接查看，无连线)

# 3.2 将模型转换成脚本，再保存，查看有连线

script_model = torch.jit.script(net)
torch.jit.save(script_model, 'script_model.pth')

# 3.3 转换为通用格式 ONNX
x = torch.rand(10, 4)
torch.onnx.export(net, (x,), 'net.onnx')